import  torch

def main():
    """Demonstrate tensor dtype inspection/conversion and a valid matrix product.

    Prints the Python type and dtype of a small tensor, converts it to
    float16, then multiplies two random (2, 3) tensors after transposing
    the second so the inner dimensions agree.

    Returns:
        torch.Tensor: the (2, 2) product ``tensor1 @ tensor2.T``.
    """
    x = torch.tensor([1.2, 2.1], dtype=torch.float32)
    print(f"type x: {type(x)}")
    print(f"dtype x: {x.dtype}")

    # .to() returns a new tensor; rebind x so later code sees the float16 copy.
    x = x.to(torch.float16)
    print(f"dtype x: {x.dtype}")

    tensor1 = torch.randn(2, 3)
    tensor2 = torch.randn(2, 3)
    # matmul requires the inner dimensions to agree: (2, 3) @ (3, 2) -> (2, 2).
    # Multiplying the two (2, 3) tensors directly raises a RuntimeError.
    product = torch.matmul(tensor1, tensor2.T)
    print(f"product shape: {tuple(product.shape)}")
    return product


if __name__ == '__main__':
    main()
